import copy
import os
import sys

# Make the directory one level above this file importable, presumably so that
# project-local packages (e.g. witin_nn) resolve when this file is run directly.
# The entry is appended, so anything already on sys.path takes precedence.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

import torch
import witin_nn
import torch.nn as nn
import witin_nn.interface

class Config:
    """Select a witin_nn global config and hand out indexed per-layer configs.

    Exactly one global config is chosen at construction. ``qat`` takes
    precedence over ``nat``; with neither set, the float (torch) training
    config is used. In every mode the default layer config serves as the
    template for the per-layer configs returned by ``get_layer_config``.

    Args:
        qat: use ``GlobalConfigFactory.get_qat_train_wtm2100_config()``.
        nat: use ``GlobalConfigFactory.get_qat_nat_train_wtm2100_config()``
            (ignored when ``qat`` is true).
    """

    def __init__(self, qat: bool = False, nat: bool = False):
        global_factory = witin_nn.interface.GlobalConfigFactory
        if qat:
            self.global_config = global_factory.get_qat_train_wtm2100_config()
        elif nat:
            self.global_config = global_factory.get_qat_nat_train_wtm2100_config()
        else:
            self.global_config = global_factory.get_float_train_torch_config()

        # Template for every per-layer config; get_layer_config() copies it.
        self.layer_config = witin_nn.interface.LayerConfigFactory.get_default_config()
        # Index assigned to the next config returned by get_layer_config().
        self.layer_index = 0
        print(self.global_config)

    def get_layer_config(self):
        """Return a copy of the layer-config template tagged with the next index.

        Each call increments ``self.layer_index``. Layers should therefore
        request their configs in a fixed construction order.

        Returns:
            A deep copy of ``self.layer_config`` whose ``index`` is the value of
            ``self.layer_index`` before the increment.
        """
        # Copy so each layer owns its config. Mutating and returning the shared
        # template would leave every layer holding the same object, with
        # ``index`` equal to whichever value was assigned last.
        layer_config = copy.deepcopy(self.layer_config)
        layer_config.index = self.layer_index
        self.layer_index += 1
        return layer_config

class ResidualBlock(nn.Module):
    """Two-convolution residual block assembled from witin_nn layers.

    Main path: two stages of 3x3 conv -> batch norm -> GELU. Only the first
    conv uses ``stride``.

    Shortcut: when ``stride != 1`` or the channel count changes, a 1x1
    conv -> batch norm -> GELU projection is used; otherwise it is an empty
    ``nn.Sequential`` (identity). Unlike the classic ResNet block, both
    branches end in a GELU before the element-wise add, and another GELU
    follows the add.

    Each witin_nn layer takes its config from ``config.get_layer_config()``,
    so the creation order below fixes the index each layer receives.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        config: object supplying per-layer configs via ``get_layer_config()``.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, inchannel, outchannel, config: Config, stride=1):
        super().__init__()

        def conv_bn_gelu(cin, cout, **conv_kwargs):
            # Argument evaluation is left to right, so configs are drawn in
            # the order conv, batch norm, GELU.
            return nn.Sequential(
                witin_nn.WitinConv2d(cin, cout, **conv_kwargs, bias=False, layer_config=config.get_layer_config()),
                witin_nn.WitinBatchNorm2d(cout, layer_config=config.get_layer_config()),
                witin_nn.WitinGELU(layer_config=config.get_layer_config()),
            )

        self.left1 = conv_bn_gelu(inchannel, outchannel, kernel_size=3, stride=stride, padding=1)
        self.left2 = conv_bn_gelu(outchannel, outchannel, kernel_size=3, stride=1, padding=1)

        needs_projection = stride != 1 or inchannel != outchannel
        self.shortcut = (
            conv_bn_gelu(inchannel, outchannel, kernel_size=1, stride=stride)
            if needs_projection
            else nn.Sequential()
        )

        self.gelu = witin_nn.WitinGELU(layer_config=config.get_layer_config())
        self.add = witin_nn.WitinElementAdd(layer_config=config.get_layer_config())

    def forward(self, x):
        """Return ``gelu(add(left2(left1(x)), shortcut(x)))``."""
        main_branch = self.left2(self.left1(x))
        skip_branch = self.shortcut(x)
        summed = self.add(main_branch, skip_branch)
        return self.gelu(summed)

class ResNet18(nn.Module):
    """ResNet-style classifier assembled from witin_nn layers.

    Despite the name, only three residual stages are built, each with two
    ``ResidualBlock``s: 64, 128 and 256 channels, with first-block strides
    1, 2 and 2. There is no fourth (512-channel) stage.

    Stem: 3x3 conv -> batch norm -> GELU on a 3-channel input.
    Head: 4x4/stride-4 conv -> batch norm -> GELU, then dropout(0.2), a
    flatten, and a ``WitinLinear(1024, num_classes)``.

    NOTE(review): 1024 = 256 * 2 * 2, which is what a 32x32 input produces
    through these strides (e.g. CIFAR-10). Other input sizes will not match
    ``fc`` — confirm the intended input resolution.

    Args:
        num_classes: output size of the final linear layer.
        qat: forwarded to ``Config`` (QAT global config).
        nat: forwarded to ``Config`` (QAT+NAT global config; ignored if ``qat``).
    """

    def __init__(self, num_classes=10, qat: bool = False, nat: bool = False):
        super().__init__()
        cfg = Config(qat=qat, nat=nat)
        self.config = cfg
        # Input channel count for the next stage; advanced by make_layer().
        self.inchannel = 64
        # Not read anywhere in this class; kept as a public attribute.
        self.quantization_params = {}

        def conv_bn_gelu(cin, cout, **conv_kwargs):
            # Configs are drawn in the order conv, batch norm, GELU.
            return nn.Sequential(
                witin_nn.WitinConv2d(cin, cout, **conv_kwargs, bias=False, layer_config=cfg.get_layer_config()),
                witin_nn.WitinBatchNorm2d(cout, layer_config=cfg.get_layer_config()),
                witin_nn.WitinGELU(layer_config=cfg.get_layer_config()),
            )

        self.conv1 = conv_bn_gelu(3, 64, kernel_size=3, stride=1, padding=1)
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.conv2 = conv_bn_gelu(256, 256, kernel_size=4, stride=4)
        self.dropout = nn.Dropout(0.2)
        self.fc = witin_nn.WitinLinear(1024, num_classes, layer_config=cfg.get_layer_config())

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one residual stage as an ``nn.Sequential``.

        The first block receives ``stride`` and the remaining blocks receive 1.
        Each block is constructed as
        ``block(self.inchannel, channels, self.config, block_stride)``, and
        ``self.inchannel`` is set to ``channels`` after each one.

        Note: ``num_blocks <= 0`` still yields a single block, because
        ``[1] * (num_blocks - 1)`` is then empty.
        """
        stage_strides = [stride, *([1] * (num_blocks - 1))]
        blocks = []
        for block_stride in stage_strides:
            blocks.append(block(self.inchannel, channels, self.config, block_stride))
            self.inchannel = channels
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run stem, residual stages and head, then flatten into ``fc``.

        The output is presumably ``(batch, num_classes)`` logits, depending on
        ``WitinLinear``.
        """
        features = x
        for stage in (self.conv1, self.layer1, self.layer2, self.layer3, self.conv2, self.dropout):
            features = stage(features)
        flattened = features.view(features.size(0), -1)
        return self.fc(flattened)

if __name__ == '__main__':
    # Smoke check: build the default 10-class model (float training config)
    # and print its module tree.
    net = ResNet18(num_classes=10)
    print(net)
    

